#!/usr/bin/env python
# coding: utf-8

import json
import os
import numpy as np
from xml.etree.ElementTree import Element, ElementTree

# Dataset root directory.
# NOTE(review): main() walks "." directly and does not read this constant in
# the visible code; confirm whether anything else uses it.
ROOT = "."
# Output directory for the generated XML annotations; main() mirrors each
# directory that contains JSON files underneath it.
LABEL_ROOT = "Label"

# img_suffix = IMG_FILES[0][-3:]


def check_img_exist(root, json_files, image_files):
    """Drop JSON annotations that have no matching image file.

    For each name in *json_files* (expected to end in ``json``), look for a
    sibling ``<stem>.jpg`` or ``<stem>.png`` in *image_files*. An annotation
    without a matching image is deleted from disk (``root/<name>``) and removed
    from the list.

    Args:
        root: Directory that contains the JSON files.
        json_files: JSON file names inside *root*. The list is updated in place.
        image_files: Image file names inside *root* (any iterable of str).

    Returns:
        The same *json_files* list object, now holding only the annotations
        that have an image.
    """
    image_names = set(image_files)  # O(1) membership instead of a list scan
    kept = []
    for name in json_files:
        file_prefix = name[:-4]  # "foo.json" -> "foo."
        if file_prefix + 'jpg' in image_names or file_prefix + 'png' in image_names:
            kept.append(name)
        else:
            os.remove(os.path.join(root, name))
            print('no img {}'.format(file_prefix))
    # Rebuild the list in place rather than calling remove() during iteration,
    # which silently skipped the element following every removed entry.
    json_files[:] = kept
    return json_files


# Convert a dict into XML element data.
def dict_to_xml(tag, d, sub_elem=None):
    """Build an ``Element`` named *tag* from the mapping *d*.

    Every key/value pair of *d* becomes a child element whose tag is the key
    and whose text is ``str(value)``, kept in the dict's iteration order.
    When *sub_elem* is truthy, its elements are appended after those children.

    Returns:
        The newly created ``Element``.
    """
    leaves = [Element(key) for key in d]
    for leaf, value in zip(leaves, d.values()):
        leaf.text = str(value)

    node = Element(tag)
    node.extend(leaves)
    if sub_elem:
        node.extend(sub_elem)
    return node


# Save data to a JSON file.
def save_json(data_name, j_data):
    """Serialize *j_data* to the file path *data_name*, overwriting it.

    The file is written as UTF-8 with a four-space indent. Non-ASCII
    characters are written as-is rather than escaped (``ensure_ascii=False``).
    """
    with open(data_name, mode="w", encoding="utf-8") as handle:
        json.dump(j_data, handle, indent=4, ensure_ascii=False)


# Convert a JSON file in place, clearing its imageData content.
def json_convert(root, json_path):
    """Rewrite the JSON file at *json_path* without its ``imageData`` payload.

    ``imageData`` is set to ``None``. The ``label`` of the first entry in
    ``shapes`` is replaced by ``root[2:]``, which is *root* minus its first two
    characters, e.g. a leading ``"./"``. The result is written back to
    *json_path* via ``save_json``.
    """
    with open(json_path, mode='r', encoding='utf-8') as handle:
        payload = json.load(handle)

    payload["imageData"] = None
    first_shape = payload['shapes'][0]
    first_shape['label'] = root[2:]
    save_json(json_path, payload)


def _convert_shape_to_xml(json_path, xml_path):
    """Write a VOC-style XML file for the first shape of the JSON at *json_path*.

    Only ``shapes[0]`` is converted. Its ``points`` must hold exactly two
    ``[x, y]`` corners; anything else raises when unpacked.
    """
    with open(json_path, encoding='utf-8', mode='r') as f:
        json_dict = json.load(f)
    item = json_dict['shapes'][0]
    (x1, y1), (x2, y2) = item['points']
    # Nothing guarantees the corners are stored top-left first, so normalise
    # them to keep xmin <= xmax and ymin <= ymax.
    bbox = dict_to_xml('bndbox', {
        'xmin': int(min(x1, x2)),
        'xmax': int(max(x1, x2)),
        'ymin': int(min(y1, y2)),
        'ymax': int(max(y1, y2)),
    })
    object_ = dict_to_xml('object', {'name': item["label"], 'difficult': 0}, [bbox])
    anno = dict_to_xml('annotation', {"tmp": "tmp"}, [object_])
    ElementTree(anno).write(xml_path, encoding='utf-8')


def _write_pairs(txt_path, image_list, xml_list, indices):
    """Write one ``"<image> <xml>"`` line per index, with no trailing newline."""
    lines = ["{} {}".format(image_list[k], xml_list[k]) for k in indices]
    with open(txt_path, "w", encoding="utf-8") as fp:
        fp.write("\n".join(lines))


def main():
    """Convert every JSON annotation under "." to XML and write split lists.

    All outputs are relative to the current working directory:
      * ``label_list``: the line "background" plus one line per directory that
        contains JSON files (the directory path without its leading "./").
      * ``<LABEL_ROOT>/<dir>/<name>.xml`` for each JSON file that has a
        matching .png/.jpg. ``check_img_exist`` deletes JSON files that lack one.
      * ``road_train.txt`` / ``road_eval.txt``: a random 80/20 split of
        ``"<image path> <xml path>"`` lines.
    """
    image_list = []
    xml_list = []
    # makedirs raises on failure, so no separate existence assert is needed.
    os.makedirs(LABEL_ROOT, exist_ok=True)

    with open("label_list", "w", encoding="utf-8") as label_list:
        label_list.write("background")
        for root, _dirs, files in os.walk(".", topdown=True):
            json_files = [name for name in files if name.endswith('json')]
            if not json_files:
                continue
            image_files = {name for name in files if name.endswith(('png', 'jpg'))}
            rel_dir = root[2:]  # strip the "./" prefix that os.walk(".") yields

            label_list.write("\n")
            label_list.write(rel_dir)
            label_dir = os.path.join(LABEL_ROOT, rel_dir)
            if not os.path.exists(label_dir):
                os.makedirs(label_dir)
                print("新建%s目录" % label_dir)

            # Drop (and delete) annotations whose image file is missing.
            json_files = check_img_exist(root, json_files, image_files)
            for json_file in json_files:
                json_path = os.path.join(rel_dir, json_file)
                stem = json_file[:-4]  # "foo.json" -> "foo."
                xml_path = os.path.join(LABEL_ROOT, rel_dir, stem + "xml")
                _convert_shape_to_xml(json_path, xml_path)
                # Point at the image that actually exists (png preferred,
                # matching the previous default) instead of always ".png".
                ext = 'png' if stem + 'png' in image_files else 'jpg'
                image_list.append(json_path[:-4] + ext)
                xml_list.append(xml_path)

    num = len(image_list)
    train_num = int(num * 0.8)
    eval_num = num - train_num
    order = np.arange(num)
    np.random.shuffle(order)
    _write_pairs("road_train.txt", image_list, xml_list, order[:train_num])
    _write_pairs("road_eval.txt", image_list, xml_list, order[train_num:])
    print(num, train_num, eval_num)


if __name__ == '__main__':
    main()
